{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AWQ on LLaVA"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we use LLaVA model to demonstrate the performance of AWQ on multi-modal models. We implement AWQ real-INT4 inference kernels, which are wrapped as Pytorch modules and can be easily used by existing models. We also provide a simple example to show how to use AWQ to quantize a model and save/load the quantized model checkpoint."
   ]
  },
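  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build some intuition for how such kernels slot into a model, the sketch below shows a toy drop-in replacement for `nn.Linear` that stores 4-bit grouped weights with per-group scales and zero points and dequantizes them on the fly. This is purely illustrative and does not reflect AWQ's actual `WQLinear` layout; the real kernels fuse dequantization into a CUDA GEMM instead of materializing FP16 weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Toy 4-bit linear layer (illustrative only, NOT AWQ's real WQLinear):\n",
    "# weights are kept as uint8 codes with one scale/zero point per group of 128\n",
    "# and dequantized on the fly; real kernels fuse this into the GEMM instead.\n",
    "class ToyW4Linear(nn.Module):\n",
    "    def __init__(self, linear, group_size=128):\n",
    "        super().__init__()\n",
    "        w = linear.weight.data.float()  # [out_features, in_features]\n",
    "        out_f, in_f = w.shape\n",
    "        wg = w.view(out_f, in_f // group_size, group_size)\n",
    "        w_min = wg.amin(-1, keepdim=True)\n",
    "        w_max = wg.amax(-1, keepdim=True)\n",
    "        self.register_buffer(\"scales\", (w_max - w_min).clamp(min=1e-5) / 15)\n",
    "        self.register_buffer(\"zeros\", (-w_min / self.scales).round())\n",
    "        q = (wg / self.scales + self.zeros).round().clamp(0, 15)\n",
    "        self.register_buffer(\"qweight\", q.to(torch.uint8))\n",
    "        self.bias = linear.bias\n",
    "\n",
    "    def forward(self, x):\n",
    "        w = (self.qweight.float() - self.zeros) * self.scales\n",
    "        return F.linear(x, w.view(w.shape[0], -1).to(x.dtype), self.bias)\n",
    "\n",
    "fp16 = nn.Linear(256, 8)\n",
    "q4 = ToyW4Linear(fp16)\n",
    "x = torch.randn(2, 256)\n",
    "print((fp16(x) - q4(x)).abs().max())  # small quantization error"
   ]
  },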
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to run this notebook, you need to install the following packages:\n",
    "- [AWQ](https://github.com/mit-han-lab/llm-awq)\n",
    "- [Pytorch](https://pytorch.org/)\n",
    "- [Accelerate](https://github.com/huggingface/accelerate)\n",
    "- [LLaVA](https://github.com/haotian-liu/LLaVA)\n",
    "- [Transformers](https://github.com/huggingface/transformers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jilin/anaconda3/envs/llava/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import requests\n",
    "from PIL import Image\n",
    "from io import BytesIO\n",
    "from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
    "from transformers import AutoTokenizer, CLIPVisionModel, CLIPImageProcessor, logging\n",
    "logging.set_verbosity_error()  # too many warnings\n",
    "from llava.conversation import conv_templates, SeparatorStyle\n",
    "from llava.utils import disable_torch_init\n",
    "from llava.model import *\n",
    "from llava.model.utils import KeywordsStoppingCriteria\n",
    "from awq.quantize.pre_quant import apply_awq\n",
    "from awq.quantize.quantizer import real_quantize_model_weight\n",
    "import os\n",
    "import gc\n",
    "\n",
    "# This demo only support single GPU for now\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "DEFAULT_IMAGE_TOKEN = \"<image>\"\n",
    "DEFAULT_IMAGE_PATCH_TOKEN = \"<im_patch>\"\n",
    "DEFAULT_IM_START_TOKEN = \"<im_start>\"\n",
    "DEFAULT_IM_END_TOKEN = \"<im_end>\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please get the LLaVA model from [LLaVA](https://github.com/haotian-liu/LLaVA) and run the following cell to generate a quantized model checkpoint first (note that we only quantize the language decoder, which dominates the model parameters). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:09<00:00,  3.14s/it]\n",
      "real weight quantization...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [09:07<00:00, 13.69s/it]\n"
     ]
    }
   ],
   "source": [
    "model_path = \"/dataset/llava/LLaVA-13B-v0\"  # Please change here \n",
    "quant_path = \"../quant_cache/LLaVA-13B-v0-w4-g128-awq.pt\"  # place to dump quant weights\n",
    "\n",
    "model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()\n",
    "\n",
    "awq_results = torch.load(\"../awq_cache/llava-13b-v0-w4-g128.pt\", map_location=\"cpu\")\n",
    "apply_awq(model, awq_results)\n",
    "\n",
    "real_quantize_model_weight(model, w_bit=4, q_config={\"zero_point\": True, \"q_group_size\": 128})\n",
    "torch.save(model.cpu().state_dict(), quant_path)\n",
    "\n",
    "del model\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
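  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `w_bit=4` and `q_group_size=128`, every group of 128 weights shares one scale and one zero point, so the per-weight cost is 4 bits plus a small amount of amortized metadata. A rough back-of-the-envelope estimate of the savings for the ~13B-parameter decoder (illustrative arithmetic only; it ignores the unquantized embeddings and vision tower and assumes FP16 scales with 4-bit zero points):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough storage estimate for w4 / group size 128 (back-of-the-envelope only)\n",
    "n_params = 13e9                  # ~13B decoder weights\n",
    "fp16_gb = n_params * 2 / 1e9     # 2 bytes per FP16 weight\n",
    "# 4-bit weight + (FP16 scale + 4-bit zero point) amortized over each 128-group\n",
    "bits_per_weight = 4 + (16 + 4) / 128\n",
    "w4_gb = n_params * bits_per_weight / 8 / 1e9\n",
    "print(f\"FP16: ~{fp16_gb:.0f} GB, W4-g128: ~{w4_gb:.1f} GB\")"
   ]
  },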
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then input a image link and a question below.\n",
    "\n",
    "![](https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg)\n",
    "\n",
    "## Q: What is unusual about this image?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What is unusual about this image?\"\n",
    "image_file = \"https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg\" "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first load a empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.17s/it]\n",
      "real weight quantization...(init only): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:37<00:00,  1.08it/s]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "disable_torch_init()\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
    "config = LlavaConfig.from_pretrained(model_path)\n",
    "with init_empty_weights():\n",
    "    model = LlavaLlamaForCausalLM.from_pretrained(model_path, config=config,\n",
    "                                                    torch_dtype=torch.float16, device_map=\"auto\")\n",
    "q_config = {\"zero_point\": True, \"q_group_size\": 128}\n",
    "real_quantize_model_weight(\n",
    "    model, w_bit=4, q_config=q_config, init_only=True)\n",
    "\n",
    "model = load_checkpoint_and_dispatch(\n",
    "    model, quant_path, device_map=\"auto\"\n",
    ")"
   ]
  },
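  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `init_only=True`, `real_quantize_model_weight` only swaps each `nn.Linear` in the decoder for a `WQLinear` module with empty INT4 buffers, so that `load_checkpoint_and_dispatch` can stream the quantized weights directly into them. The general swap pattern looks roughly like the helper below (an illustrative sketch, not AWQ's actual implementation):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# Illustrative module-swap pattern (not AWQ's actual code): recursively\n",
    "# replace every nn.Linear with a quantized drop-in built by `make_quant`.\n",
    "def replace_linear_layers(module, make_quant):\n",
    "    for name, child in module.named_children():\n",
    "        if isinstance(child, nn.Linear):\n",
    "            setattr(module, name, make_quant(child))\n",
    "        else:\n",
    "            replace_linear_layers(child, make_quant)\n",
    "\n",
    "# e.g. with the toy layer from the first sketch:\n",
    "# replace_linear_layers(model.model, lambda lin: ToyW4Linear(lin))"
   ]
  },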
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jilin/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py:1211: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The unusual aspect of this image is that a man is standing on a portable ironing board in the middle of the road, ironing clothes while traffic, including a yellow taxi, moves around him. This is not a typical scene you would expect to see in a city, as ironing is usually done in a private setting like a home, and not on the street amidst traffic. It brings attention to the unconventional and unexpected nature of the situation.\n"
     ]
    }
   ],
   "source": [
    "def load_image(image_file):\n",
    "    if image_file.startswith('http') or image_file.startswith('https'):\n",
    "        response = requests.get(image_file)\n",
    "        image = Image.open(BytesIO(response.content)).convert('RGB')\n",
    "    else:\n",
    "        image = Image.open(image_file).convert('RGB')\n",
    "    return image\n",
    "\n",
    "image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)\n",
    "\n",
    "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
    "tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
    "if mm_use_im_start_end:\n",
    "    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
    "\n",
    "vision_tower = model.get_model().vision_tower[0]\n",
    "if vision_tower.device.type == 'meta':\n",
    "    vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()\n",
    "    model.get_model().vision_tower[0] = vision_tower\n",
    "else:\n",
    "    vision_tower.to(device='cuda', dtype=torch.float16)\n",
    "vision_config = vision_tower.config\n",
    "vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]\n",
    "vision_config.use_im_start_end = mm_use_im_start_end\n",
    "if mm_use_im_start_end:\n",
    "    vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])\n",
    "image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2\n",
    "\n",
    "qs = query\n",
    "if mm_use_im_start_end:\n",
    "    qs = qs + '\\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN\n",
    "else:\n",
    "    qs = qs + '\\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len\n",
    "\n",
    "conv_mode = \"multimodal\"\n",
    "\n",
    "conv = conv_templates[conv_mode].copy()\n",
    "conv.append_message(conv.roles[0], qs)\n",
    "conv.append_message(conv.roles[1], None)\n",
    "prompt = conv.get_prompt()\n",
    "inputs = tokenizer([prompt])\n",
    "\n",
    "image = load_image(image_file)\n",
    "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n",
    "\n",
    "input_ids = torch.as_tensor(inputs.input_ids).cuda()\n",
    "\n",
    "stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
    "keywords = [stop_str]\n",
    "stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
    "\n",
    "with torch.inference_mode():\n",
    "    output_ids = model.generate(\n",
    "        input_ids,\n",
    "        images=image_tensor.unsqueeze(0).half().cuda(),\n",
    "        do_sample=True,\n",
    "        temperature=0.2,\n",
    "        max_new_tokens=1024,\n",
    "        stopping_criteria=[stopping_criteria])\n",
    "\n",
    "input_token_len = input_ids.shape[1]\n",
    "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
    "if n_diff_input_output > 0:\n",
    "    print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
    "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
    "outputs = outputs.strip()\n",
    "if outputs.endswith(stop_str):\n",
    "    outputs = outputs[:-len(stop_str)]\n",
    "outputs = outputs.strip()\n",
    "print(outputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
